from __future__ import print_function

import ast
import random
import time
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter

from . import gf
from . import hope
from . import lap
from . import line
from . import lle
from . import node2vec
from . import sdne
from . import tadw
from .graph import *
from .grarep import GraRep


def parse_args(argv=None):
    """Build the command-line parser for the embedding driver and parse it.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse falls back to ``sys.argv[1:]``, which
            matches the previous behaviour.

    Returns:
        argparse.Namespace with one attribute per option below. Only
        ``--method`` is required; every other option has a default.
    """
    from argparse import ArgumentTypeError

    def str2bool(value):
        """Convert a textual boolean such as 'true'/'false', 'yes'/'no' or '1'/'0' to bool."""
        # Without an explicit converter argparse stores the raw string, so
        # "--directed False" used to produce the truthy string 'False'.
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise ArgumentTypeError('expected a boolean value, got {!r}'.format(value))

    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter,
                            conflict_handler='resolve')
    parser.add_argument('--dataset', nargs='?', default='clf/brazil-flights',
                        help='Input graph path')

    parser.add_argument('--number-walks', default=10, type=int,
                        help='Number of random walks to start at each node')
    parser.add_argument('--directed', default=False, type=str2bool,
                        help='Treat graph as directed (true/false).')
    parser.add_argument('--walk-length', default=80, type=int,
                        help='Length of the random walk started at each node')
    parser.add_argument('--workers', default=8, type=int,
                        help='Number of parallel processes.')
    parser.add_argument('--representation-size', default=128, type=int,
                        help='Number of latent dimensions to learn for each node.')
    parser.add_argument('--window-size', default=10, type=int,
                        help='Window size of skipgram model.')
    parser.add_argument('--epochs', default=5, type=int,
                        help='The training epochs of LINE, SDNE, GF and GCN')
    parser.add_argument('--p', default=1.0, type=float)
    parser.add_argument('--q', default=1.0, type=float)
    parser.add_argument('--method', required=True, choices=[
        'node2vec',
        'deepwalk',
        'line',
        'gcn',
        'grarep',
        'tadw',
        'lle',
        'hope',
        'lap',
        'gf',
        'sdne'
    ], help='The learning method')
    parser.add_argument('--label-file', default='',
                        help='The file of node label')
    parser.add_argument('--feature-file', default='',
                        help='The file of node features')
    parser.add_argument('--graph-format', default='edgelist', choices=['adjlist', 'edgelist'],
                        help='Input graph format')
    parser.add_argument('--negative-ratio', default=5, type=int,
                        help='the negative ratio of LINE')
    parser.add_argument('--weighted', default=False, type=str2bool,
                        help='Treat graph as weighted (true/false).')
    parser.add_argument('--clf-ratio', default=0.5, type=float,
                        help='The ratio of training data in the classification')
    parser.add_argument('--order', default=3, type=int,
                        help='Choose the order of LINE, 1 means first order, 2 means second '
                             'order, 3 means first order + second order')
    parser.add_argument('--no-auto-save', action='store_true',
                        help='no save the best embeddings when training LINE')
    parser.add_argument('--dropout', default=0.5, type=float,
                        help='Dropout rate (1 - keep probability)')
    parser.add_argument('--weight-decay', type=float, default=5e-4,
                        help='Weight for L2 loss on embedding matrix')
    parser.add_argument('--hidden', default=16, type=int,
                        help='Number of units in hidden layer 1')
    parser.add_argument('--kstep', default=4, type=int,
                        help='Use k-step transition probability matrix')
    parser.add_argument('--lamb', default=0.2, type=float,
                        help='lambda is a hyperparameter in TADW')
    parser.add_argument('--lr', default=0.01, type=float,
                        help='learning rate')
    parser.add_argument('--alpha', default=1e-6, type=float,
                        help='alpha is a hyperparameter in SDNE')
    parser.add_argument('--beta', default=5., type=float,
                        help='beta is a hyperparameter in SDNE')
    parser.add_argument('--nu1', default=1e-5, type=float,
                        help='nu1 is a hyperparameter in SDNE')
    parser.add_argument('--nu2', default=1e-4, type=float,
                        help='nu2 is a hyperparameter in SDNE')
    parser.add_argument('--bs', default=200, type=int,
                        help='batch size of SDNE')
    parser.add_argument('--Fmethod', nargs='?', default='ReFeX', help='Method')
    parser.add_argument('--sp', nargs='?', default=0.7, type=float, help='split ratio')
    parser.add_argument('--lpmethod', nargs='?', default='Hadamard', help='binary operator')
    # Kept as a string on purpose: the consumer parses it into a list of ints.
    parser.add_argument('--encoder-list', default='[1000, 128]', type=str,
                        help='a list of numbers of the neuron at each encoder layer, '
                             'the last number is the '
                             'dimension of the output node representation')
    args = parser.parse_args(argv)

    return args


def tasks(args, embed):
    """Evaluate node embeddings on the downstream task implied by the dataset name.

    Datasets whose name starts with ``"clf"`` are evaluated with node
    classification against ``../dataset/<dataset>.lbl``. Datasets starting with
    ``"lp"`` are evaluated with link prediction, using the held-out positive
    edges and test negatives cached in ``../cache/<dataset>-1.pkl``.

    Args:
        args: Parsed CLI namespace; reads ``dataset``, ``sp`` and ``lpmethod``.
        embed: Embeddings forwarded unchanged to the ``Task`` evaluator.

    Returns:
        Whatever the ``Task`` evaluator returns, or ``None`` when the dataset
        name matches neither prefix.
    """
    if args.dataset.startswith("clf"):
        task = Task('CLF')
        # np.int was only an alias of the builtin int and was removed in NumPy 1.24.
        label = np.loadtxt("../dataset/" + args.dataset + '.lbl', dtype=int)
        return task.classfication(embed, label, split_ratio=args.sp, loop=50)

    if args.dataset.startswith("lp"):
        task = Task('LP')
        # NOTE(review): pickle.load executes arbitrary code; only load trusted,
        # locally generated cache files here.
        with open('../cache/{}-1.pkl'.format(args.dataset), 'rb') as datafile:
            graph_attr = pkl.load(datafile)
        positive_edges = graph_attr['edge_rmvd']
        negative_edges = graph_attr['edge_test_neg']
        # Positives (label 1) first, then negatives (label 0). The shuffle below
        # sees the same sequence and length as before, so the permutation is
        # unchanged for a given RNG state. Unlike the old zip(*...) unpacking,
        # this does not crash when both lists are empty.
        pairs = [(1, edge) for edge in positive_edges]
        pairs.extend((0, edge) for edge in negative_edges)
        random.shuffle(pairs)
        test_label = np.array([lbl for lbl, _ in pairs])
        test_sample = [edge for _, edge in pairs]
        return task.link_prediction(embed, test_sample, test_label, split_ratio=args.sp,
                                    method=args.lpmethod, loop=1)

    return None


def _build_model(args, g):
    """Instantiate the embedding model selected by ``args.method`` on graph ``g``."""
    method = args.method
    if method == 'node2vec':
        return node2vec.Node2vec(graph=g, path_length=args.walk_length,
                                 num_paths=args.number_walks, dim=args.representation_size,
                                 workers=args.workers, p=args.p, q=args.q,
                                 window=args.window_size)
    if method == 'line':
        if args.label_file and not args.no_auto_save:
            return line.LINE(g, epoch=args.epochs, rep_size=args.representation_size,
                             order=args.order,
                             label_file=args.label_file, clf_ratio=args.clf_ratio)
        return line.LINE(g, epoch=args.epochs,
                         rep_size=args.representation_size, order=args.order)
    if method == 'deepwalk':
        # DeepWalk reuses the node2vec walker with dw=True.
        return node2vec.Node2vec(graph=g, path_length=args.walk_length,
                                 num_paths=args.number_walks, dim=args.representation_size,
                                 workers=args.workers, window=args.window_size, dw=True)
    if method == 'tadw':
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not args.feature_file:
            raise ValueError('TADW requires --feature-file')
        # NOTE(review): the label-file check was disabled upstream, so this may be
        # called with '' -- confirm Graph.read_node_label tolerates that.
        g.read_node_label(args.label_file)
        g.read_node_features(args.feature_file)
        return tadw.TADW(graph=g, dim=args.representation_size, lamb=args.lamb)
    if method == 'grarep':
        return GraRep(graph=g, Kstep=args.kstep, dim=args.representation_size)
    if method == 'lle':
        return lle.LLE(graph=g, d=args.representation_size)
    if method == 'hope':
        return hope.HOPE(graph=g, d=args.representation_size)
    if method == 'sdne':
        # --encoder-list arrives as a string such as "[1000, 128]"; literal_eval
        # parses it safely (no arbitrary code execution).
        encoder_layer_list = ast.literal_eval(args.encoder_list)
        return sdne.SDNE(g, encoder_layer_list=encoder_layer_list,
                         alpha=args.alpha, beta=args.beta, nu1=args.nu1, nu2=args.nu2,
                         batch_size=args.bs, epoch=args.epochs, learning_rate=args.lr)
    if method == 'lap':
        return lap.LaplacianEigenmaps(g, rep_size=args.representation_size)
    if method == 'gf':
        return gf.GraphFactorization(g, rep_size=args.representation_size,
                                     epoch=args.epochs, learning_rate=args.lr,
                                     weight_decay=args.weight_decay)
    # 'gcn' is accepted by the CLI but has no implementation here. Previously this
    # fell through and crashed later with UnboundLocalError on `model`.
    raise NotImplementedError('embedding method {!r} is not supported'.format(method))


def main(args):
    """Read the input graph, build the selected embedding model, and save its embeddings.

    The graph is read from ``temp/dataset/<dataset>.edge``, and the embeddings
    are written via ``model.save_embeddings`` to ``temp/embed/<dataset>.emb``.
    The elapsed wall-clock time of model construction is printed as
    "Embedding time".

    Args:
        args: Parsed CLI namespace; see ``parse_args``.

    Raises:
        ValueError: unknown ``graph_format``, or TADW without a feature file.
        NotImplementedError: ``method`` has no model implementation (e.g. 'gcn').
    """
    g = Graph()
    print("Reading...")
    input_path = "temp/dataset/" + args.dataset + ".edge"
    output_path = "temp/embed/" + args.dataset + '.emb'
    if args.graph_format == 'adjlist':
        g.read_adjlist(filename=input_path)
    elif args.graph_format == 'edgelist':
        g.read_edgelist(filename=input_path, weighted=args.weighted,
                        directed=args.directed)
    else:
        raise ValueError('unsupported graph format {!r}'.format(args.graph_format))
    t1 = time.time()
    print("Start time: {}".format(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(t1))))
    model = _build_model(args, g)
    t2 = time.time()
    print("Embedding time: {}".format(t2 - t1))
    model.save_embeddings(output_path)


if __name__ == "__main__":
    # Seed Python's `random` module first, then NumPy's global RNG, both with
    # the same fixed value, so repeated runs draw identical random sequences.
    for _seed_rng in (random.seed, np.random.seed):
        _seed_rng(32)
    main(parse_args())
